/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#ifndef SRC_PROTOCOL_H_
#define SRC_PROTOCOL_H_

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#if defined(RUBY_BINDING)
#include <ruby.h>
#endif

#include <sys/types.h>
#if defined(IS_OPENBSD)
#include <netinet/in_systm.h>
#include <net/ethertypes.h>
#else
#include <net/ethernet.h>
#endif

#include <chrono>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <boost/utility/string_ref.hpp>
#include "UnitConverter.h"
#include "Pointer.h"
#include "FlowForwarder.h"
#include "Multiplexer.h"
#include "DatabaseAdaptor.h"
#include "ipset/IPSetManager.h"
#include "names/DomainNameManager.h"
#include "CounterMap.h"
#include "Cache.h"
#include "Message.h"
#include "Logger.h"
#include "ElapsedTime.h"
#include "StringMap.h"
#include "CPUCycle.h"

namespace aiengine {

// Forward declaration: this header only refers to Flow through pointers,
// so the full definition is not needed here.
class Flow;

#if defined(JAVA_BINDING)

// Counter name -> value map; only defined for Java binding builds.
using JavaCounters = std::map<std::string, int32_t>;

#elif defined(RUBY_BINDING)

// Data bundle used by the Ruby binding build. Judging by the field names it
// presumably describes a method call: receiver `obj`, method `method_id`,
// and `nargs` of the (at most four) values in `args` — TODO confirm.
struct ruby_shared_data {
	VALUE obj;
	ID method_id;
	int nargs;
	VALUE args[4];
};

#elif defined(LUA_BINDING)

// Counter name -> value map; only defined for Lua binding builds.
using LuaCounters = std::map<std::string, int32_t>;

#endif

// Abstract base class for the protocol handlers.
//
// Concrete subclasses implement the pure virtual interface (header access,
// packet/flow processing, statistics, counters, cache and memory
// management). The base class itself keeps the packet/byte/CPU-cycle
// counters, the active flag, the protocol name and layer, and the links to
// the Multiplexer and FlowForwarder. Optional features (domain-name
// managers, anomaly manager, reject callbacks, per-language database
// adaptors) have empty or zero-returning default implementations, so a
// subclass only overrides what it supports.
//
// NOTE(review): keep the declaration order of virtual functions and data
// members stable; it fixes the vtable layout and the member initialisation
// order that the out-of-line constructors rely on.
class Protocol {
public:
	// Constructors and destructor are defined out of line.
	explicit Protocol(const std::string &name, uint16_t protocol_layer);
	explicit Protocol(const std::string &name);
	virtual ~Protocol();

	virtual uint16_t getId() const = 0;
	virtual uint16_t getHeaderSize() const = 0;
	virtual void setHeader(const uint8_t *raw_packet) = 0;

	// Statistics output: to a text stream, to JSON, and (optionally) a JSON
	// dump of the map called `map_name`. The last overload does nothing
	// unless a subclass overrides it.
	virtual void statistics(std::basic_ostream<char> &out, int level, int32_t limit) const = 0;
	virtual void statistics(Json &out, int level) const = 0;
	virtual void statistics(Json & /* out */, const std::string & /* map_name */, int32_t /* limit */) const {}

	void setStatisticsLevel(int level) { stats_level_ = level; }

	// Read-only access to the traffic counters kept in the protected members.
	uint64_t getTotalBytes() const { return total_bytes_; }
	uint64_t getTotalPackets() const { return total_packets_; }
	uint64_t getTotalValidPackets() const { return total_valid_packets_; }
	uint64_t getTotalInvalidPackets() const { return total_invalid_packets_; }

	// Protocol name as a C string; the pointer refers to this object's own
	// storage and is only valid while the object is alive.
	const char* name() const { return name_.c_str(); }

	// Enabled flag: getter and setter.
	bool active() const { return is_active_; }
	void active(bool value) { is_active_ = value; }

	virtual bool check(const Packet &packet) = 0;
	virtual void processFlow(Flow *flow) = 0;
	virtual bool processPacket(Packet &packet) = 0;

	// The Multiplexer and FlowForwarder are held through weak handles (per
	// the type names), so this object does not keep them alive. The setters
	// take the handle by value and move it into place.
	void setMultiplexer(MultiplexerPtrWeak mux) { mux_ = std::move(mux); }
	MultiplexerPtrWeak getMultiplexer() const { return mux_; }

	void setFlowForwarder(WeakPointer<FlowForwarder> ff) { flow_forwarder_ = std::move(ff); }
	WeakPointer<FlowForwarder> getFlowForwarder() const { return flow_forwarder_; }

	uint16_t getProtocolLayer() const { return protocol_layer_; }

	void infoMessage(const std::string& msg);

	// Reset the internal statistics of the protocol.
	virtual void resetCounters() = 0;
	void reset();

	// Clear cache resources.
	virtual void releaseCache() = 0;

	// Memory consumption of the protocol and its cache items.
	virtual uint64_t getAllocatedMemory() const = 0;
	// Memory consumption of all the memory used.
	virtual uint64_t getTotalAllocatedMemory() const = 0;
	// Memory currently used by the caches.
	virtual uint64_t getCurrentUseMemory() const = 0;
	uint64_t getTotalCpuCycles() const { return total_cpu_cycles_; }

	virtual void setDynamicAllocatedMemory(bool value) = 0;
	virtual bool isDynamicAllocatedMemory() const = 0;

	// Used mainly by the bindings; no-ops unless overridden.
	virtual void increaseAllocatedMemory(int /* value */) {}
	virtual void decreaseAllocatedMemory(int /* value */) {}

	// Optional (non pure virtual) hooks: the defaults below do nothing or
	// return 0. Parameter names are commented out because the defaults do
	// not use them (avoids -Wunused-parameter noise).
	virtual void setDomainNameManager(const SharedPointer<DomainNameManager> & /* dnm */) {}
	virtual void setDomainNameBanManager(const SharedPointer<DomainNameManager> & /* dnm */) {}

	virtual void releaseFlowInfo(Flow * /* flow */) {}

#if defined(HAVE_REJECT_FLOW)
	virtual void addRejectFunction(std::function <void (Flow*)> /* reject */) {}
#endif
	virtual void setAnomalyManager(SharedPointer<AnomalyManager> /* amng */) {}

	virtual uint32_t getTotalCacheMisses() const { return 0; }
	virtual uint32_t getTotalEvents() const { return 0; }

	virtual CounterMap getCounters() const = 0;

#if defined(PYTHON_BINDING)
	// Default: empty dict / no cache; overridden by protocols that have caches.
	virtual boost::python::dict getCacheData(const std::string & /* name */) const { return boost::python::dict(); }
	virtual SharedPointer<Cache<StringCache>> getCache(const std::string & /* name */) { return nullptr; }

	void setDatabaseAdaptor(boost::python::object &dbptr, int packet_sampling);

	boost::python::dict addMapToHash(const StringMap &mt, const char *header = "") const;

	void setOnFailCacheCallback(PyObject *callback);
	PyObject *getOnFailCacheCallback() const { return cache_fail_callback_.getCallback(); }

#elif defined(RUBY_BINDING)
	// Default: nil; overridden by protocols that have caches.
	virtual VALUE getCacheData(const std::string & /* name */) const { return Qnil; }
	void setDatabaseAdaptor(VALUE dbptr, int packet_sampling);

	VALUE addMapToHash(const StringMap &mt, const char *header = "") const;
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
	void setDatabaseAdaptor(DatabaseAdaptor *dbptr, int packet_sampling);
#elif defined(LUA_BINDING)
	void setDatabaseAdaptor(lua_State *L, const char *obj_name, int packet_sampling);
#endif

#if defined(BINDING)

	bool getDatabaseObjectIsSet() const { return is_set_db_; }
	int getPacketSampling() const { return packet_sampling_; }

	void databaseAdaptorInsertHandler(Flow *flow);
	void databaseAdaptorUpdateHandler(Flow *flow);
	void databaseAdaptorRemoveHandler(Flow *flow);
#endif
	void setIPSetManager(const SharedPointer<IPSetManager> ipset_mng);

	// NOTE(review): these data members are public in the existing interface
	// and may be accessed directly elsewhere; kept public for compatibility.
	SharedPointer<IPSetManager> ipset_mng_ = nullptr;
	MultiplexerPtrWeak mux_ = MultiplexerPtrWeak();
	WeakPointer<FlowForwarder> flow_forwarder_ = WeakPointer<FlowForwarder>();
protected:
	// NOTE(review): meaning of the uint32_t result is not visible here.
	uint32_t releaseStringToCache(Cache<StringCache>::CachePtr &cache, const SharedPointer<StringCache> &item);

	void showStatisticsHeader(Json &out, int level) const;
	void showStatisticsHeader(std::basic_ostream<char> &out, int level) const;
	void logFailCache(std::string_view name, Flow *flow);
	void logFailCache(std::string_view name);

	void computeMemoryUtilization(std::ostringstream &out, uint64_t tcu, uint64_t trf, uint64_t tcs) const;

	uint64_t total_valid_packets_ = 0;
	uint64_t total_invalid_packets_ = 0;
	uint64_t total_packets_ = 0;
	uint64_t total_bytes_ = 0;
	uint64_t total_cpu_cycles_ = 0;
	int stats_level_ = 0;
	// constexpr static data members are implicitly inline (C++17), so this
	// constant can be ODR-used (e.g. bound to a const& by std::max) without
	// an out-of-line definition. Presumably paired with last_cache_log_fail_
	// to rate-limit cache-failure logging — TODO confirm in the .cpp.
	static constexpr int max_seconds_between_cache_fails = 5;
	// Defaults to the local time at construction unless a constructor
	// initialises it explicitly.
	boost::posix_time::ptime data_time_ = boost::posix_time::microsec_clock::local_time();
private:
	std::string name_ {};
	uint16_t protocol_layer_ = 0; // TCP or UDP
	std::time_t last_cache_log_fail_ = 0;
	bool is_active_ = false;
#if defined(BINDING)
	std::ostringstream key_ {};
	std::ostringstream data_ {};
	int packet_sampling_ = 32;
	bool is_set_db_ = false;
#if defined(PYTHON_BINDING)
	boost::python::object dbptr_ {};
	Callback cache_fail_callback_;
#elif defined(RUBY_BINDING)
	VALUE dbptr_ = Qnil;
#elif defined(JAVA_BINDING) || defined(GO_BINDING)
	DatabaseAdaptor *dbptr_ = nullptr;
#elif defined(LUA_BINDING)
	lua_State *L_ = nullptr;
	int ref_function_insert_ = LUA_NOREF;
	int ref_function_update_ = LUA_NOREF;
	int ref_function_remove_ = LUA_NOREF;
#endif
#endif // defined(BINDING)
};

typedef std::shared_ptr <Protocol> ProtocolPtr;
typedef std::weak_ptr <Protocol> ProtocolPtrWeak;

} // namespace aiengine

#endif  // SRC_PROTOCOL_H_
